from multiprocessing import Pool
from multiprocessing import Process
from glob import glob
import os


def _build_gen_grasp_cmd(scale, specified_instance_idx, cuda_idx, obj_name,
                         custm_grasp_pose, *, nn_envs, headless, init_obj_z,
                         hand_type, hand_facing_dir, omni_wrist_ornt,
                         additional_tag):
    """Return the shell command line for one ``gen_grasp_leap_udir.sh`` job.

    The grasp-cache name has the form
    ``{hand_type}_down_[{hand_facing_dir}_]init0d38_{obj_name}[_custm]{additional_tag}``:
    the facing-direction segment is present only when ``omni_wrist_ornt`` is
    truthy, and ``_custm`` only when ``custm_grasp_pose`` is truthy. In the
    latter case ``task.env.customizeGraspPose=True`` is also added to the
    overrides. All other overrides are emitted unconditionally.
    """
    cache_name = f"{hand_type}_down_"
    if omni_wrist_ornt:
        cache_name += f"{hand_facing_dir}_"
    # NOTE(review): "init0d38" is a fixed string and does not follow
    # init_obj_z; keep them in sync by hand if init_obj_z changes.
    cache_name += f"init0d38_{obj_name}"
    if custm_grasp_pose:
        cache_name += "_custm"
    cache_name += additional_tag

    custm_flag = " task.env.customizeGraspPose=True" if custm_grasp_pose else ""

    # The double spaces are kept exactly as in the original hand-written
    # command strings; the shell collapses them anyway.
    return (
        f"bash scripts/gen_grasp_leap_udir.sh {cuda_idx} {scale}  True {nn_envs} "
        f"task.env.objectDownfacingInitZ={init_obj_z} "
        f"task.env.grasp_cache_name={cache_name}  "
        f"task.env.object.type={obj_name}{custm_flag}  "
        f"task.env.object.specifiedObjectIdx='{specified_instance_idx}' "
        f"headless={headless} task.env.handType={hand_type} "
        f"task.env.object.nnInsts=100 task.env.omniWristOrnt={omni_wrist_ornt} "
        f"task.env.handGraspFacingDir={hand_facing_dir}"
    )


def launch_one_process(scale, specified_instance_idx, cuda_idx, obj_name, custm_grasp_pose=False):
    """Build, print and run one ``scripts/gen_grasp_leap_udir.sh`` job.

    Args:
        scale: Object scale, passed as the script's second positional argument.
        specified_instance_idx: Object instance index, forwarded as
            ``task.env.object.specifiedObjectIdx``.
        cuda_idx: Passed as the script's first positional argument
            (presumably a GPU index).
        obj_name: Forwarded as ``task.env.object.type`` and embedded in the
            grasp-cache name.
        custm_grasp_pose: If truthy, adds ``task.env.customizeGraspPose=True``
            and a ``_custm`` suffix to the grasp-cache name.

    Besides its arguments, this reads the module-level globals ``debug``,
    ``hand_type``, ``additional_tag``, ``omniWristOrnt``, ``handFacingDir`` and
    ``init_obj_z``. They are only assigned inside this file's
    ``if __name__ == '__main__'`` block.

    NOTE(review): a child started with the ``spawn`` or ``forkserver`` start
    method re-imports this module without running that block, so those
    globals would be undefined there. Run this in-process or in a
    ``fork``-started child.
    """
    # Full-size headless run normally; a small run with rendering when debugging.
    nn_envs = 20000 if not debug else 100
    headless = not debug

    gen_grasp_sh = _build_gen_grasp_cmd(
        scale, specified_instance_idx, cuda_idx, obj_name, custm_grasp_pose,
        nn_envs=nn_envs,
        headless=headless,
        init_obj_z=init_obj_z,
        hand_type=hand_type,
        hand_facing_dir=handFacingDir,
        omni_wrist_ornt=omniWristOrnt,
        additional_tag=additional_tag,
    )

    print(gen_grasp_sh)
    status = os.system(gen_grasp_sh)
    # os.system's return value was previously ignored; surface failures so a
    # broken job in a large sweep does not go unnoticed.
    if status != 0:
        print(f"[launch_one_process] command failed (os.system status {status}): {gen_grasp_sh}")


# Usage: python gen_grasp_pool.py

if __name__ == '__main__':
    # Launch one grasp-generation job per (scale, object instance) pair and
    # run them in batches of at most one job per listed GPU.
    #
    # The settings below must stay module-level globals:
    # launch_one_process() reads debug, hand_type, additional_tag,
    # omniWristOrnt, handFacingDir and init_obj_z directly.
    import multiprocessing

    debug = False

    hand_type = 'leap'

    obj_name = 'cylinder_default'

    additional_tag = ''

    # When truthy, handFacingDir is also embedded in the grasp-cache name.
    omniWristOrnt = True

    handFacingDir = 'palm_down'

    custm_grasp_pose = True

    init_obj_z = 0.38

    asset_root = 'assets'

    # Resolve the URDF directory for this object family.
    if "grab" in obj_name:
        tot_obj_urdfs = sorted(glob(f'{asset_root}/grab/*.urdf'))
    elif "dexenv" in obj_name:
        tot_obj_urdfs = sorted(glob(f'{asset_root}/dexenv/*.urdf'))
    else:
        # Expects exactly "<category>_<subset>", e.g. "cylinder_default".
        obj_category, subset_name = obj_name.split("_")
        tot_obj_urdfs = sorted(glob(f'{asset_root}/{obj_category}/{subset_name}/*.urdf'))

    tot_obj_inst_idxes = [0, 1, 2, 3, 4, 5, 6, 7, 8]

    # The jobs receive instance indices, not URDF paths, so this selection only
    # checks that every requested instance exists. Fail with a clear message
    # instead of a bare IndexError, e.g. when the relative asset path does not
    # resolve from the current working directory.
    try:
        tot_obj_urdfs = [tot_obj_urdfs[cur_idx] for cur_idx in tot_obj_inst_idxes]
    except IndexError as err:
        raise IndexError(
            f"Found only {len(tot_obj_urdfs)} URDF(s) for obj_name='{obj_name}' "
            f"under '{asset_root}', but instance indices {tot_obj_inst_idxes} were requested"
        ) from err

    scales_list = [0.7, 0.72, 0.74, 0.76, 0.78, 0.8, 0.82, 0.84, 0.86]

    cuda_lists = [0, 1, 2, 3, 4, 5, 6, 7]
    if debug:
        cuda_lists = [0]

    # GPUs are assigned round-robin over the job list, so each batch of
    # len(cuda_lists) consecutive jobs uses every GPU at most once.
    tot_launch_args = []
    for s in scales_list:
        for cur_obj_instance_idx in tot_obj_inst_idxes:
            cur_cuda_idx = cuda_lists[len(tot_launch_args) % len(cuda_lists)]
            tot_launch_args.append(
                (s, cur_obj_instance_idx, cur_cuda_idx, obj_name, custm_grasp_pose)
            )
    print(f"Total {len(tot_launch_args)} launch args")

    if debug:
        tot_launch_args = tot_launch_args[:1]

    # launch_one_process() reads globals assigned only in this block. 'fork'
    # children inherit them, whereas 'spawn'/'forkserver' children re-import
    # the module without running this block, so request 'fork' explicitly.
    mp_ctx = multiprocessing.get_context('fork')

    # Wait for each whole batch before starting the next, so no GPU is given
    # two of these jobs at the same time.
    maxx_pool_size = len(cuda_lists)
    for i_st in range(0, len(tot_launch_args), maxx_pool_size):
        processes = []
        for launch_arg in tot_launch_args[i_st:i_st + maxx_pool_size]:
            p = mp_ctx.Process(target=launch_one_process, args=launch_arg)
            processes.append(p)
            p.start()

        for p in processes:
            p.join()
            if p.exitcode != 0:
                print(f"Process {p.name} exited with code {p.exitcode}")



